iT邦幫忙

2025 iThome 鐵人賽

DAY 5
0
AI & Data

實戰派 AI 工程師帶你 0->1系列 第 5

Day5: self attention 實作

  • 分享至 

  • xImage
  •  

前情提要

昨天花了很多時間在介紹 embedding 跟弱化的 softmax(XX^T)X,如果昨天了解個大概那今天就不太會有問題。

文章參考及圖片來源: https://www.cnblogs.com/rossiXYZ/p/18751758 , https://zhuanlan.zhihu.com/p/410776234

1. Self attention

核心觀念: 加權求和 → 根據目前的位置去關注序列中的其他位置,來取得有用的資訊。
後續延伸回到 Attention(Q, K, V) 公式

https://ithelp.ithome.com.tw/upload/images/20250830/20168446LZN7zjq5lZ.png
昨天有說到 Q, K, V 本質上都是由 x 轉換而來的,這個轉換是透過線性轉換 (程式對應 nn.Linear),以下用兩張圖來說明
  https://ithelp.ithome.com.tw/upload/images/20250830/20168446mYYT7FnbG3.png
  https://ithelp.ithome.com.tw/upload/images/20250830/20168446irwZDHjDkT.jpg
  圖片來源: https://www.cnblogs.com/hbuwyg/p/16978264.html
Q: 為什麼不直接使用 X 而要透過線性變換呢 ??
A: 主要是為了模型擬合程度,W矩陣是可以訓練,可以提高模型能力

Q: Q, K, V 分別代表什麼意思 ??
A:

  • Q 指的是 query, query 向量代表目前正在處理的 token 或位置的需求資訊。 (在序列中想找誰跟我最相關)
  • K 指的是 key,key 向量代表序列中每個 token 的標籤,用來和 Query 做比較,判斷"相關程度"。
  • V 指的是 value,value 向量代表每個 token 的實際資訊或特徵,會乘上權重來產生輸出。

2. 實作

這裡我們可以先看一下大家都是怎麼取名的,大致上會分成以下四種

  • q, k, v
  • query, key, value
  • linear_q, linear_k, linear_v
  • q_proj, k_proj, v_proj

另外維度大小通常是以下名稱

  • dims
  • d_model
  • hidden_size
  • embed_dim, embed_size
  • n_feat

這裡我們照幾個步驟完成,可以照著步驟一起想,先試著寫寫看,如果不行就先照打,確定可以之後再重來一次,這是我認為學最快的方式。
以下名稱採用 linear_q, 及 hidden_size

  1. 定義最基本的 class (init + forward) → 問自己 x 輸入的維度是多少
  2. Q, K, V 都是由 x 經過線性變化來的 → 所以要定義三個 nn.Linear(hidden_size, hidden_size)
    分母的 dk 上面沒有特別描述,就是一個 scaling → hidden_size ** -0.5
  3. 在 forward 準備要做計算
    1. x 做線性轉換 → query, key, value
    2. qk 內積 → torch.matmul, 乘以 scaling
    3. softmax → 得到 attn_weights
    4. 乘以 value → 得到最終 output
# step 1
import torch
from torch import nn
import torch.nn.functional as F

class MySelfAttention(nn.Module):
    def __init__(self):
        super().__init__()    

    def forward(self, x: torch.Tensor):
        '''
            B: batch size
            L: seq len
            D: embedding dimension
            x: (B, L, D) or (B, L, E) 簡寫每個都不太一樣
        '''
        return

# step 2
import torch
from torch import nn
import torch.nn.functional as F

class MySelfAttention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()    

        self.linear_q = nn.Linear(hidden_size, hidden_size)
        self.linear_k = nn.Linear(hidden_size, hidden_size)
        self.linear_v = nn.Linear(hidden_size, hidden_size)

        self.scaling = hidden_size ** -0.5

    def forward(self, x: torch.Tensor):
        '''
            B: batch size
            L: seq len
            D: embedding dimension
            x: (B, L, D) or (B, L, E) 簡寫每個都不太一樣
        '''
        return
# step 3
import torch
from torch import nn
import torch.nn.functional as F

class MySelfAttention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()    

        self.linear_q = nn.Linear(hidden_size, hidden_size)
        self.linear_k = nn.Linear(hidden_size, hidden_size)
        self.linear_v = nn.Linear(hidden_size, hidden_size)

        self.scaling = hidden_size ** -0.5

    def forward(self, x: torch.Tensor):
        '''
            B: batch size
            L: seq len
            D: embedding dimension
            x: (B, L, D) or (B, L, E) 簡寫每個都不太一樣
        '''
        query = self.linear_q(x)
        key = self.linear_k(x)
        value = self.linear_v(x)

        # (B, L, D) dot (B, D, L) = (B, L, L)
        attn_scores  = torch.matmul(query, key.permute(0, 2, 1)) * self.scaling
        attn_weights = F.softmax(attn_weights, dim = -1)

        # (B, L, L) dot (B, L, D) = (B, L, D)
        attn_output = torch.matmul(attn_weights, value)

        return attn_output
        
if __name__ == "__main__":
    model = MySelfAttention(64)
    x = torch.rand(2, 100, 64)
    y = model(x)
    print(y.shape)

今天就以程式為主,花了點時間分成多個步驟,希望能讓你更好理解和實作,今天練完之後確定會,可以明天看著公式自己嘗試一次,相信沒多久你就會上手了,今天就先到這囉 ~~ 明天會有個小總結,來幫助你更了解。


上一篇
Day4: embedding & attention 觀念
下一篇
Day6: self attention 總結 & MHA 觀念
系列文
實戰派 AI 工程師帶你 0->18
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言